{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---+---------+------+----------+---------------+------------------+-----------------+------------+-----+------+------------+------------+--------------+--------------+------+\n",
      "|age|workclass|fnlwgt| education|educational-num|    marital-status|       occupation|relationship| race|gender|capital-gain|capital-loss|hours-per-week|native-country|income|\n",
      "+---+---------+------+----------+---------------+------------------+-----------------+------------+-----+------+------------+------------+--------------+--------------+------+\n",
      "| 25|  Private|226802|      11th|              7|     Never-married|Machine-op-inspct|   Own-child|Black|  Male|           0|           0|            40| United-States| <=50K|\n",
      "| 38|  Private| 89814|   HS-grad|              9|Married-civ-spouse|  Farming-fishing|     Husband|White|  Male|           0|           0|            50| United-States| <=50K|\n",
      "| 28|Local-gov|336951|Assoc-acdm|             12|Married-civ-spouse|  Protective-serv|     Husband|White|  Male|           0|           0|            40| United-States|  >50K|\n",
      "+---+---------+------+----------+---------------+------------------+-----------------+------------+-----+------+------------+------------+--------------+--------------+------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "\n",
    "spark = SparkSession.builder.appName('adult_logReg').getOrCreate()\n",
    "df = spark.read.csv('adult.csv', inferSchema = True, header=True)\n",
    "df.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "cols = df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- age: integer (nullable = true)\n",
      " |-- workclass: string (nullable = true)\n",
      " |-- fnlwgt: integer (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- educational-num: integer (nullable = true)\n",
      " |-- marital-status: string (nullable = true)\n",
      " |-- occupation: string (nullable = true)\n",
      " |-- relationship: string (nullable = true)\n",
      " |-- race: string (nullable = true)\n",
      " |-- gender: string (nullable = true)\n",
      " |-- capital-gain: integer (nullable = true)\n",
      " |-- capital-loss: integer (nullable = true)\n",
      " |-- hours-per-week: integer (nullable = true)\n",
      " |-- native-country: string (nullable = true)\n",
      " |-- income: string (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer, VectorAssembler\n",
    "\n",
    "categoricalColumns = [\"workclass\", \"education\", \"marital-status\", \"occupation\", \"relationship\", \"race\", \"gender\", \"native-country\"]\n",
    "stages = []\n",
    "\n",
    "for categoricalCol in categoricalColumns:\n",
    "    stringIndexer = StringIndexer(inputCol = categoricalCol, outputCol = categoricalCol + 'Index')\n",
    "    encoder = OneHotEncoderEstimator(inputCols=[stringIndexer.getOutputCol()], outputCols=[categoricalCol + \"classVec\"])\n",
    "    stages += [stringIndexer, encoder]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_stringIdx = StringIndexer(inputCol = 'income', outputCol = 'label')\n",
    "stages += [label_stringIdx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "numericCols = [\"age\", \"fnlwgt\", \"educational-num\", \"capital-gain\", \"capital-loss\", \"hours-per-week\"]\n",
    "assemblerInputs = [c + \"classVec\" for c in categoricalColumns] + numericCols\n",
    "assembler = VectorAssembler(inputCols=assemblerInputs, outputCol=\"features\")\n",
    "stages += [assembler]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+--------------------+---+---------+------+----------+---------------+------------------+-----------------+------------+-----+------+------------+------------+--------------+--------------+------+\n",
      "|label|            features|age|workclass|fnlwgt| education|educational-num|    marital-status|       occupation|relationship| race|gender|capital-gain|capital-loss|hours-per-week|native-country|income|\n",
      "+-----+--------------------+---+---------+------+----------+---------------+------------------+-----------------+------------+-----+------+------------+------------+--------------+--------------+------+\n",
      "|  0.0|(100,[0,13,24,35,...| 25|  Private|226802|      11th|              7|     Never-married|Machine-op-inspct|   Own-child|Black|  Male|           0|           0|            40| United-States| <=50K|\n",
      "|  0.0|(100,[0,8,23,39,4...| 38|  Private| 89814|   HS-grad|              9|Married-civ-spouse|  Farming-fishing|     Husband|White|  Male|           0|           0|            50| United-States| <=50K|\n",
      "|  1.0|(100,[2,14,23,41,...| 28|Local-gov|336951|Assoc-acdm|             12|Married-civ-spouse|  Protective-serv|     Husband|White|  Male|           0|           0|            40| United-States|  >50K|\n",
      "+-----+--------------------+---+---------+------+----------+---------------+------------------+-----------------+------------+-----+------+------------+------------+--------------+--------------+------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pipeline = Pipeline(stages=stages)\n",
    "pipelineModel = pipeline.fit(df)\n",
    "df = pipelineModel.transform(df)\n",
    "selectedcols = [\"label\", \"features\"] + cols\n",
    "df = df.select(selectedcols)\n",
    "df.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataFrame[label: double, features: vector, age: int, workclass: string, fnlwgt: int, education: string, educational-num: int, marital-status: string, occupation: string, relationship: string, race: string, gender: string, capital-gain: int, capital-loss: int, hours-per-week: int, native-country: string, income: string]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "34255\n",
      "14587\n"
     ]
    }
   ],
   "source": [
    "train, test = df.randomSplit([0.7, 0.3], seed=100)\n",
    "print(train.count())\n",
    "print(test.count())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logistic Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import LogisticRegression\n",
    "\n",
    "lr = LogisticRegression(labelCol = 'label', featuresCol = 'features', maxIter=10)\n",
    "lrModel = lr.fit(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Row(label=0.0, features=SparseVector(100, {0: 1.0, 8: 1.0, 23: 1.0, 29: 1.0, 43: 1.0, 48: 1.0, 52: 1.0, 53: 1.0, 94: 26.0, 95: 58426.0, 96: 9.0, 99: 50.0}), age=26, workclass='Private', fnlwgt=58426, education='HS-grad', educational-num=9, marital-status='Married-civ-spouse', occupation='Prof-specialty', relationship='Husband', race='White', gender='Male', capital-gain=0, capital-loss=0, hours-per-week=50, native-country='United-States', income='<=50K', rawPrediction=DenseVector([0.8176, -0.8176]), probability=DenseVector([0.6937, 0.3063]), prediction=0.0)]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = lrModel.transform(test)\n",
    "predictions.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- label: double (nullable = false)\n",
      " |-- features: vector (nullable = true)\n",
      " |-- age: integer (nullable = true)\n",
      " |-- workclass: string (nullable = true)\n",
      " |-- fnlwgt: integer (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- educational-num: integer (nullable = true)\n",
      " |-- marital-status: string (nullable = true)\n",
      " |-- occupation: string (nullable = true)\n",
      " |-- relationship: string (nullable = true)\n",
      " |-- race: string (nullable = true)\n",
      " |-- gender: string (nullable = true)\n",
      " |-- capital-gain: integer (nullable = true)\n",
      " |-- capital-loss: integer (nullable = true)\n",
      " |-- hours-per-week: integer (nullable = true)\n",
      " |-- native-country: string (nullable = true)\n",
      " |-- income: string (nullable = true)\n",
      " |-- rawPrediction: vector (nullable = true)\n",
      " |-- probability: vector (nullable = true)\n",
      " |-- prediction: double (nullable = false)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataFrame[label: double, prediction: double, probability: vector, age: int, occupation: string]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "selected = predictions.select(\"label\", \"prediction\", \"probability\", \"age\", \"occupation\")\n",
    "display(selected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------+--------------------+---+--------------+\n",
      "|label|prediction|         probability|age|    occupation|\n",
      "+-----+----------+--------------------+---+--------------+\n",
      "|  0.0|       0.0|[0.69373398473136...| 26|Prof-specialty|\n",
      "|  0.0|       0.0|[0.61163146131025...| 30|Prof-specialty|\n",
      "|  0.0|       0.0|[0.66067198545007...| 31|Prof-specialty|\n",
      "|  0.0|       0.0|[0.66129318473475...| 32|Prof-specialty|\n",
      "+-----+----------+--------------------+---+--------------+\n",
      "only showing top 4 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "selected.show(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Area Under ROC 0.9023600168373183\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "\n",
    "evaluator = BinaryClassificationEvaluator(rawPredictionCol=\"rawPrediction\")\n",
    "print('Area Under ROC', evaluator.evaluate(predictions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'areaUnderROC'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluator.getMetricName()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "aggregationDepth: suggested depth for treeAggregate (>= 2). (default: 2)\n",
      "elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. (default: 0.0)\n",
      "family: The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial (default: auto)\n",
      "featuresCol: features column name. (default: features, current: features)\n",
      "fitIntercept: whether to fit an intercept term. (default: True)\n",
      "labelCol: label column name. (default: label, current: label)\n",
      "lowerBoundsOnCoefficients: The lower bounds on coefficients if fitting under bound constrained optimization. The bound matrix must be compatible with the shape (1, number of features) for binomial regression, or (number of classes, number of features) for multinomial regression. (undefined)\n",
      "lowerBoundsOnIntercepts: The lower bounds on intercepts if fitting under bound constrained optimization. The bounds vector size must beequal with 1 for binomial regression, or the number oflasses for multinomial regression. (undefined)\n",
      "maxIter: max number of iterations (>= 0). (default: 100, current: 10)\n",
      "predictionCol: prediction column name. (default: prediction)\n",
      "probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. (default: probability)\n",
      "rawPredictionCol: raw prediction (a.k.a. confidence) column name. (default: rawPrediction)\n",
      "regParam: regularization parameter (>= 0). (default: 0.0)\n",
      "standardization: whether to standardize the training features before fitting the model. (default: True)\n",
      "threshold: Threshold in binary classification prediction, in range [0, 1]. If threshold and thresholds are both set, they must match.e.g. if threshold is p, then thresholds must be equal to [1-p, p]. (default: 0.5)\n",
      "thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. (undefined)\n",
      "tol: the convergence tolerance for iterative algorithms (>= 0). (default: 1e-06)\n",
      "upperBoundsOnCoefficients: The upper bounds on coefficients if fitting under bound constrained optimization. The bound matrix must be compatible with the shape (1, number of features) for binomial regression, or (number of classes, number of features) for multinomial regression. (undefined)\n",
      "upperBoundsOnIntercepts: The upper bounds on intercepts if fitting under bound constrained optimization. The bound vector size must be equal with 1 for binomial regression, or the number of classes for multinomial regression. (undefined)\n",
      "weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0. (undefined)\n"
     ]
    }
   ],
   "source": [
    "print(lr.explainParams())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.tuning import ParamGridBuilder, CrossValidator\n",
    "\n",
    "# Create ParamGrid for Cross Validation\n",
    "paramGrid = (ParamGridBuilder()\n",
    "             .addGrid(lr.regParam, [0.01, 0.5, 2.0])\n",
    "             .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])\n",
    "             .addGrid(lr.maxIter, [1, 5, 10])\n",
    "             .build())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "cv = CrossValidator(estimator=lr, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=5)\n",
    "\n",
    "# Run cross validations\n",
    "cvModel = cv.fit(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Area Under ROC 0.9014836720704358\n"
     ]
    }
   ],
   "source": [
    "predictions = cvModel.transform(test)\n",
    "print('Area Under ROC', evaluator.evaluate(predictions))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decision Trees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import DecisionTreeClassifier\n",
    "\n",
    "dt = DecisionTreeClassifier(featuresCol = 'features', labelCol = 'label', maxDepth = 3)\n",
    "dtModel = dt.fit(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numNodes =  15\n",
      "depth =  3\n"
     ]
    }
   ],
   "source": [
    "print(\"numNodes = \", dtModel.numNodes)\n",
    "print(\"depth = \", dtModel.depth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- label: double (nullable = false)\n",
      " |-- features: vector (nullable = true)\n",
      " |-- age: integer (nullable = true)\n",
      " |-- workclass: string (nullable = true)\n",
      " |-- fnlwgt: integer (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- educational-num: integer (nullable = true)\n",
      " |-- marital-status: string (nullable = true)\n",
      " |-- occupation: string (nullable = true)\n",
      " |-- relationship: string (nullable = true)\n",
      " |-- race: string (nullable = true)\n",
      " |-- gender: string (nullable = true)\n",
      " |-- capital-gain: integer (nullable = true)\n",
      " |-- capital-loss: integer (nullable = true)\n",
      " |-- hours-per-week: integer (nullable = true)\n",
      " |-- native-country: string (nullable = true)\n",
      " |-- income: string (nullable = true)\n",
      " |-- rawPrediction: vector (nullable = true)\n",
      " |-- probability: vector (nullable = true)\n",
      " |-- prediction: double (nullable = false)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions = dtModel.transform(test)\n",
    "predictions.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataFrame[label: double, prediction: double, probability: vector, age: int, occupation: string]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "selected = predictions.select(\"label\", \"prediction\", \"probability\", \"age\", \"occupation\")\n",
    "display(selected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6750538342550954"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "\n",
    "evaluator = BinaryClassificationEvaluator()\n",
    "evaluator.evaluate(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.tuning import ParamGridBuilder, CrossValidator\n",
    "paramGrid = (ParamGridBuilder()\n",
    "             .addGrid(dt.maxDepth, [1, 2, 6, 10])\n",
    "             .addGrid(dt.maxBins, [20, 40, 80])\n",
    "             .build())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create 5-fold CrossValidator\n",
    "cv = CrossValidator(estimator=dt, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=5)\n",
    "\n",
    "# Run cross validations\n",
    "cvModel = cv.fit(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numNodes =  561\n",
      "depth =  10\n"
     ]
    }
   ],
   "source": [
    "print(\"numNodes = \", cvModel.bestModel.numNodes)\n",
    "print(\"depth = \", cvModel.bestModel.depth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7930449475013471"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = cvModel.transform(test)\n",
    "evaluator.evaluate(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataFrame[label: double, prediction: double, probability: vector, age: int, occupation: string]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "selected = predictions.select(\"label\", \"prediction\", \"probability\", \"age\", \"occupation\")\n",
    "display(selected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------+--------------------+---+--------------+\n",
      "|label|prediction|         probability|age|    occupation|\n",
      "+-----+----------+--------------------+---+--------------+\n",
      "|  0.0|       0.0|[0.91479820627802...| 26|Prof-specialty|\n",
      "|  0.0|       0.0|[0.79166666666666...| 30|Prof-specialty|\n",
      "|  0.0|       0.0|[0.79166666666666...| 31|Prof-specialty|\n",
      "+-----+----------+--------------------+---+--------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "selected.show(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Random Forest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- label: double (nullable = false)\n",
      " |-- features: vector (nullable = true)\n",
      " |-- age: integer (nullable = true)\n",
      " |-- workclass: string (nullable = true)\n",
      " |-- fnlwgt: integer (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- educational-num: integer (nullable = true)\n",
      " |-- marital-status: string (nullable = true)\n",
      " |-- occupation: string (nullable = true)\n",
      " |-- relationship: string (nullable = true)\n",
      " |-- race: string (nullable = true)\n",
      " |-- gender: string (nullable = true)\n",
      " |-- capital-gain: integer (nullable = true)\n",
      " |-- capital-loss: integer (nullable = true)\n",
      " |-- hours-per-week: integer (nullable = true)\n",
      " |-- native-country: string (nullable = true)\n",
      " |-- income: string (nullable = true)\n",
      " |-- rawPrediction: vector (nullable = true)\n",
      " |-- probability: vector (nullable = true)\n",
      " |-- prediction: double (nullable = false)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.classification import RandomForestClassifier\n",
    "\n",
    "rf = RandomForestClassifier(featuresCol = 'features', labelCol = 'label')\n",
    "rfModel = rf.fit(train)\n",
    "predictions = rfModel.transform(test)\n",
    "predictions.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataFrame[label: double, prediction: double, probability: vector, age: int, occupation: string]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "selected = predictions.select(\"label\", \"prediction\", \"probability\", \"age\", \"occupation\")\n",
    "display(selected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------+--------------------+---+--------------+\n",
      "|label|prediction|         probability|age|    occupation|\n",
      "+-----+----------+--------------------+---+--------------+\n",
      "|  0.0|       0.0|[0.63089701288376...| 26|Prof-specialty|\n",
      "|  0.0|       0.0|[0.58219399032986...| 30|Prof-specialty|\n",
      "|  0.0|       0.0|[0.56792881180860...| 31|Prof-specialty|\n",
      "+-----+----------+--------------------+---+--------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "selected.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8874404939215993"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluator = BinaryClassificationEvaluator()\n",
    "evaluator.evaluate(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "paramGrid = (ParamGridBuilder()\n",
    "             .addGrid(rf.maxDepth, [2, 4, 6])\n",
    "             .addGrid(rf.maxBins, [20, 60])\n",
    "             .addGrid(rf.numTrees, [5, 20])\n",
    "             .build())\n",
    "\n",
    "# Create 5-fold CrossValidator\n",
    "cv = CrossValidator(estimator=rf, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=5)\n",
    "\n",
    "# Run cross validations.  This can take about 6 minutes since it is training over 20 trees!\n",
    "cvModel = cv.fit(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8950399552714288"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = cvModel.transform(test)\n",
    "evaluator.evaluate(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataFrame[label: double, prediction: double, probability: vector, age: int, occupation: string]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "selected = predictions.select(\"label\", \"prediction\", \"probability\", \"age\", \"occupation\")\n",
    "display(selected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------+--------------------+---+--------------+\n",
      "|label|prediction|         probability|age|    occupation|\n",
      "+-----+----------+--------------------+---+--------------+\n",
      "|  0.0|       0.0|[0.70195216177776...| 26|Prof-specialty|\n",
      "|  0.0|       0.0|[0.64177635600230...| 30|Prof-specialty|\n",
      "|  0.0|       0.0|[0.63093587802395...| 31|Prof-specialty|\n",
      "+-----+----------+--------------------+---+--------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "selected.show(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Make Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8969236717684712"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bestModel = cvModel.bestModel\n",
    "final_predictions = bestModel.transform(df)\n",
    "evaluator.evaluate(final_predictions)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
